#!/usr/bin/env python3

from random import seed, randrange
from pathlib import Path

# Register width in bits, and the all-ones mask of that width
XLEN = 32
XLEN_MASK = 2 ** XLEN - 1

# Every value with exactly one bit set, lowest bit first
all_onehot = [2 ** bit for bit in range(XLEN)]
# The one-hot values preceded by zero
all_onehot0 = [0, *all_onehot]
# The above, followed by the XLEN-bit complement of each entry
all_onehot0_neg = [*all_onehot0, *(XLEN_MASK ^ value for value in all_onehot0)]
# Every integer from 0 to XLEN - 1
all_shamt = [*range(XLEN)]

# Seed with a fixed constant so the random operands are the same on every run
seed(XLEN_MASK // 29)
# How many values each get_random() call returns
RAND_COUNT = 20
def get_random():
	"""Return a new list of RAND_COUNT random integers in [0, 2**XLEN).

	Each call advances the shared random generator, so the order in which
	callers invoke this function determines which values they receive.
	"""
	batch = []
	for _ in range(RAND_COUNT):
		batch.append(randrange(0, 1 << XLEN))
	return batch

# Lists of instructions and their test inputs

# (mnemonic, rs1 values) for instructions taking a single source register.
# Every entry gets the fixed one-hot/zero/complement patterns plus its own
# batch of random values; the mnemonic order fixes the order of the
# get_random() calls and therefore which random values each entry receives.
instr_one_operand = [
	(mnemonic, [*all_onehot0_neg, *get_random()])
	for mnemonic in (
		"clz",
		"cpop",
		"ctz",
		"orc.b",
		"rev8",
		"sext.b",
		"sext.h",
		"zext.h",
		"zip",
		"unzip",
		"brev8",
	)
]

# (mnemonic, rs1 values, rs2 values) for register-register instructions.
# For each entry the rs1 random batch is drawn before the rs2 batch, and
# entries are built in the order listed, so the random sequence consumed
# depends on this ordering.
instr_reg_reg = [
	# rs2 additionally covers every value in all_shamt
	(mnemonic, [*all_onehot0_neg, *get_random()], [*all_onehot0_neg, *all_shamt, *get_random()])
	for mnemonic in (
		"sh1add",
		"sh2add",
		"sh3add",
		"andn",
		"max",
		"maxu",
		"min",
		"minu",
		"orn",
		"rol",
		"ror",
		"xnor",
		"clmul",
		"clmulh",
		"clmulr",
		"bclr",
		"bext",
		"binv",
		"bset",
	)
] + [
	# rs2 omits the all_shamt values for these two
	(mnemonic, [*all_onehot0_neg, *get_random()], [*all_onehot0_neg, *get_random()])
	for mnemonic in ("pack", "packh")
]

# (mnemonic, rs1 values, immediate values) for register-immediate
# instructions. Each entry draws its own random rs1 batch; the immediates are
# always the full all_shamt list (the same list object is shared by all entries).
instr_reg_imm = [
	(mnemonic, [*all_onehot0_neg, *get_random()], all_shamt)
	for mnemonic in ("rori", "bclri", "bexti", "binvi", "bseti")
]

# Generate input vector programs

# Assembly text written at the start of every generated test/*.S file.
# It defines the IO_* address macros, places a handler in .vectors that
# stores -1 to IO_EXIT, and begins _start by installing that handler in
# mtvec and pointing sp at test_signature_start. The generated test bodies
# that follow store each result through sp and then advance it.
prolog = """
.option norelax

// Automatically-generated test vector. Don't edit.

#define IO_BASE 0x80000000
#define IO_PRINT_CHAR (IO_BASE + 0x0)
#define IO_PRINT_U32  (IO_BASE + 0x4)
#define IO_EXIT       (IO_BASE + 0x8)

.section .vectors
__initial_mtvec:
	li a0, IO_EXIT
	li a1, -1
	sw a1, (a0)

.section .text

.global _start
_start:
	// Should be the initial value anyways, but just make sure:
	la a0, __initial_mtvec
	csrw mtvec, a0

	la sp, test_signature_start

"""

# Assembly text written between the generated test bodies and the generated
# signature words. It stores zero to IO_EXIT, then redirects mtvec to a loop
# that executes ebreak forever, and finally opens the .testdata section with
# the test_signature_start label. The per-test labelled words are written
# immediately after this text.
interlude = """

// Hazard3 sim returns exit status if you do this:

test_end:
	li a0, IO_EXIT
	sw zero, (a0)

// Catch other simulators with or without debug support:

	la a0, it_dead
	csrw mtvec, a0
.p2align 2
it_dead:
	ebreak
	j it_dead

.section .testdata, "wa"

.p2align 2
.global test_signature_start
test_signature_start:
"""

# Assembly text written last in every generated test/*.S file: it places the
# test_signature_end label directly after the generated signature words.
epilog = """
.global test_signature_end
test_signature_end:
"""

def sanitise(name):
	"""Return name with every non-identifier character replaced by "_".

	A character is kept when it may appear inside a Python identifier; the
	check prefixes "_" so that digits, which are valid anywhere except the
	first position, are kept too. Used to turn mnemonics such as "orc.b"
	into strings usable in file names and assembly labels.
	"""
	cleaned = []
	for ch in name:
		if f"_{ch}".isidentifier():
			cleaned.append(ch)
		else:
			cleaned.append("_")
	return "".join(cleaned)

# Create the two output directories if they do not exist yet; the generated
# assembly goes to test/ and the generated C programs to refgen/.
for out_dir in ("test", "refgen"):
	Path(out_dir).mkdir(exist_ok = True)

# Write test/<mnemonic>.S for every single-operand instruction: one test case
# per input value, followed by one labelled signature slot per test case.
for instr, data in instr_one_operand:
	file_stem = sanitise(instr)
	with open(f"test/{file_stem}.S", "w") as f:
		f.write(prolog)
		# Each case loads the operand, runs the instruction, stores the
		# result at sp and advances sp to the next signature slot.
		for i, d in enumerate(data):
			f.write(f"test{i}:\n")
			f.write(f"\tli a1, 0x{d:08x}\n")
			f.write(f"\t{instr} a0, a1\n")
			f.write("\tsw a0, (sp)\n")
			f.write(f"\taddi sp, sp, {XLEN // 8}\n\n")
		f.write(interlude)
		# Label the output locations for easier debug
		for i, d in enumerate(data):
			f.write(f"test{i}_{file_stem}_{d:08x}:\n")
			# Assembler data directives start with a dot: ".dword", not "dword".
			f.write("\t.word 0\n" if XLEN == 32 else "\t.dword 0\n")
		f.write(epilog)

# Write test/<mnemonic>.S for every register-register instruction: one test
# case per (rs1, rs2) combination, followed by one labelled signature slot per
# test case.
for instr, rs1_list, rs2_list in instr_reg_reg:
	file_stem = sanitise(instr)
	# All combinations, rs2 varying fastest. Built once so the test bodies
	# and the signature labels are numbered identically.
	operand_pairs = [(rs1, rs2) for rs1 in rs1_list for rs2 in rs2_list]
	with open(f"test/{file_stem}.S", "w") as f:
		f.write(prolog)
		# Each case loads both operands, runs the instruction, stores the
		# result at sp and advances sp to the next signature slot.
		for i, (rs1, rs2) in enumerate(operand_pairs):
			f.write(f"test{i}:\n")
			f.write(f"\tli a1, 0x{rs1:08x}\n")
			f.write(f"\tli a2, 0x{rs2:08x}\n")
			f.write(f"\t{instr} a0, a1, a2\n")
			f.write("\tsw a0, (sp)\n")
			f.write(f"\taddi sp, sp, {XLEN // 8}\n\n")
		f.write(interlude)
		# Label the output locations for easier debug
		for i, (rs1, rs2) in enumerate(operand_pairs):
			f.write(f"test{i}_{file_stem}_{rs1:08x}__{rs2:08x}:\n")
			# Assembler data directives start with a dot: ".dword", not "dword".
			f.write("\t.word 0\n" if XLEN == 32 else "\t.dword 0\n")
		f.write(epilog)

# Write test/<mnemonic>.S for every register-immediate instruction: one test
# case per (rs1, immediate) combination, followed by one labelled signature
# slot per test case.
for instr, rs1_list, imm_list in instr_reg_imm:
	file_stem = sanitise(instr)
	# All combinations, immediate varying fastest. Built once so the test
	# bodies and the signature labels are numbered identically.
	operand_pairs = [(rs1, imm) for rs1 in rs1_list for imm in imm_list]
	with open(f"test/{file_stem}.S", "w") as f:
		f.write(prolog)
		# Each case loads rs1, runs the instruction with the immediate,
		# stores the result at sp and advances sp to the next signature slot.
		for i, (rs1, imm) in enumerate(operand_pairs):
			f.write(f"test{i}:\n")
			f.write(f"\tli a1, 0x{rs1:08x}\n")
			f.write(f"\t{instr} a0, a1, 0x{imm:02x}\n")
			f.write("\tsw a0, (sp)\n")
			f.write(f"\taddi sp, sp, {XLEN // 8}\n\n")
		f.write(interlude)
		# Label the output locations for easier debug
		for i, (rs1, imm) in enumerate(operand_pairs):
			f.write(f"test{i}_{file_stem}_{rs1:08x}__{imm:02x}:\n")
			# Assembler data directives start with a dot: ".dword", not "dword".
			f.write("\t.word 0\n" if XLEN == 32 else "\t.dword 0\n")
		f.write(epilog)

# Generate reference vector programs for running on spike + pk

# (I spent a while fighting spike to just be a processor + physical memory +
# some MMIO, so I could run the same binaries, but this is easier)

# C source written at the start of every generated refgen/*.c file: includes
# stdio, opens main() and declares the rd/rs1/rs2 variables that the generated
# statements assign and read.
c_prolog = """
#include <stdio.h>

// Automatically-generated test vector. Don't edit.

int main() {
	unsigned int rd, rs1, rs2;
"""
# C statement written after each generated asm statement: prints rd as eight
# hex digits on its own line, followed by a blank line in the source.
c_output_result = '\tprintf("%08x\\n", rd);\n\n'

# C source written last in every generated refgen/*.c file: returns 0 and
# closes main().
c_epilog = """
	return 0;
}
"""

# Write refgen/<mnemonic>.c for every single-operand instruction. For each
# input value the program assigns rs1, executes the instruction through inline
# asm and prints rd.
for instr, data in instr_one_operand:
	# The asm statement does not depend on the operand value, so build it once
	asm_stmt = f'\tasm("{instr} %0, %1" : "=r" (rd) : "r" (rs1));\n'
	with open(f"refgen/{sanitise(instr)}.c", "w") as f:
		f.write(c_prolog)
		for d in data:
			f.writelines((f"\trs1 = 0x{d:08x};\n", asm_stmt, c_output_result))
		f.write(c_epilog)

# Write refgen/<mnemonic>.c for every register-register instruction. For each
# (rs1, rs2) combination, rs2 varying fastest, the program assigns both
# operands, executes the instruction through inline asm and prints rd.
for instr, rs1_list, rs2_list in instr_reg_reg:
	# The asm statement does not depend on the operand values, so build it once
	asm_stmt = f'\tasm("{instr} %0, %1, %2" : "=r" (rd) : "r" (rs1), "r" (rs2));\n'
	with open(f"refgen/{sanitise(instr)}.c", "w") as f:
		f.write(c_prolog)
		for rs1 in rs1_list:
			rs1_stmt = f"\trs1 = 0x{rs1:08x};\n"
			for rs2 in rs2_list:
				f.writelines((rs1_stmt, f"\trs2 = 0x{rs2:08x};\n", asm_stmt, c_output_result))
		f.write(c_epilog)

# Write refgen/<mnemonic>.c for every register-immediate instruction. For each
# (rs1, immediate) combination, immediate varying fastest, the program assigns
# rs1, executes the instruction with the immediate passed as an "i" constraint
# operand, and prints rd.
for instr, rs1_list, imm_list in instr_reg_imm:
	with open(f"refgen/{sanitise(instr)}.c", "w") as f:
		f.write(c_prolog)
		for rs1 in rs1_list:
			rs1_stmt = f"\trs1 = 0x{rs1:08x};\n"
			for imm in imm_list:
				asm_stmt = f'\tasm("{instr} %0, %1, %2" : "=r" (rd) : "r" (rs1), "i" ({imm}));\n'
				f.writelines((rs1_stmt, asm_stmt, c_output_result))
		f.write(c_epilog)
